#include "gl_wrapper.h"
#include <cstdio>
#include <cstring>
#include <cstdlib>

// Compute work-group size (x, y, z) used to derive glDispatchCompute group
// counts. Must stay in sync with the local_size_* qualifiers in the GLSL
// compute shaders defined in this file.
int gLocalSize[3]{4, 4, 1};

// Compute shader that scatters a packed float buffer (index
// (y*width + x)*channel + c, i.e. channel innermost) into an RGBA32F 3D image:
// depth slice z of pixel (x, y) receives channels 4*z .. 4*z+3.
// Lanes of the last slice that lie past uChannel are written as 0 rather than
// being read from the neighbouring pixel's data.
// The local_size_* values must match gLocalSize on the host side.
const char* glsl_host_to_device =
  "#version 320 es\n"
  "#define PRECISION highp\n"
  "precision PRECISION float;\n"
  "#define FORMAT rgba32f\n"
  "layout(FORMAT, binding=0) writeonly uniform PRECISION image3D uImage;\n"
  "layout(binding=1) readonly buffer SSBO {\n"
  "    float data[];\n"
  "} uInBuffer;\n"
  "layout(location = 2) uniform int uWidth;\n"
  "layout(location = 3) uniform int uHeight;\n"
  "layout(location = 4) uniform int uChannel;\n"
  "layout (local_size_x = 4, local_size_y = 4, local_size_z = 1) in;\n"
  "void main()\n"
  "{\n"
  "    ivec3 pos = ivec3(gl_GlobalInvocationID);\n"
  "    if (pos.x < uWidth && pos.y < uHeight)\n"
  "    {\n"
  "        vec4 color = vec4(0.0);\n"
  "        int z = pos.z*4;\n"
  "        int base = (pos.y*uWidth + pos.x)*uChannel + z;\n"
  "        if (z+0 < uChannel) color.r = uInBuffer.data[base+0];\n"
  "        if (z+1 < uChannel) color.g = uInBuffer.data[base+1];\n"
  "        if (z+2 < uChannel) color.b = uInBuffer.data[base+2];\n"
  "        if (z+3 < uChannel) color.a = uInBuffer.data[base+3];\n"
  "        imageStore(uImage, pos, color);\n"
  "    }\n"
  "}\n";

// Compute shader that gathers an RGBA32F 3D image back into a packed float
// buffer (index (y*width + x)*channel + c, i.e. channel innermost): depth
// slice z of pixel (x, y) supplies channels 4*z .. 4*z+3.
// Lanes of the last slice past uChannel are dropped; writing them would land
// on the neighbouring pixel's slots and race with that pixel's invocation.
// The local_size_* values must match gLocalSize on the host side.
const char* glsl_device_to_host =
  "#version 320 es\n"
  "#define PRECISION highp\n"
  "precision PRECISION float;\n"
  "#define FORMAT rgba32f\n"
  "layout(FORMAT, binding=0) readonly uniform PRECISION image3D uImage;\n"
  "layout(binding=1) writeonly buffer destBuffer{\n"
  "    float data[];\n"
  "} uOutBuffer;\n"
  "layout(location = 2) uniform int uWidth;\n"
  "layout(location = 3) uniform int uHeight;\n"
  "layout(location = 4) uniform int uChannel;\n"
  "layout (local_size_x = 4, local_size_y = 4, local_size_z = 1) in;\n"
  "void main()\n"
  "{\n"
  "    ivec3 pos = ivec3(gl_GlobalInvocationID);\n"
  "    if (pos.x < uWidth && pos.y < uHeight)\n"
  "    {\n"
  "        vec4 color = imageLoad(uImage, pos);\n"
  "        int z = pos.z*4;\n"
  "        int base = (pos.y*uWidth + pos.x)*uChannel + z;\n"
  "        if (z+0 < uChannel) uOutBuffer.data[base+0] = color.r;\n"
  "        if (z+1 < uChannel) uOutBuffer.data[base+1] = color.g;\n"
  "        if (z+2 < uChannel) uOutBuffer.data[base+2] = color.b;\n"
  "        if (z+3 < uChannel) uOutBuffer.data[base+3] = color.a;\n"
  "    }\n"
  "}\n";


// Create a shader storage buffer object (SSBO) sized for `count` floats,
// initialise it from `pIn` (NULL leaves the contents uninitialised), and bind
// it to indexed SSBO binding point `index`. The buffer is also left bound to
// the generic GL_SHADER_STORAGE_BUFFER target. Nothing here deletes the
// buffer: the caller receives its name and should release it with
// glDeleteBuffers when done.
GLuint CreateSSBO(GLuint index, float* pIn, GLuint count) {
  GLuint bufferId = 0;
  glGenBuffers(1, &bufferId);
  glBindBuffer(GL_SHADER_STORAGE_BUFFER, bufferId);

  const size_t byteSize = static_cast<size_t>(count) * sizeof(float);
  glBufferData(GL_SHADER_STORAGE_BUFFER, byteSize, pIn, GL_DYNAMIC_DRAW);

  glBindBufferBase(GL_SHADER_STORAGE_BUFFER, index, bufferId);
  return bufferId;
}

// Compile a single shader stage of type `shaderType` from the NUL-terminated
// GLSL source `pSource`.
// Returns the compiled shader object's name. Returns 0 if the shader object
// could not be created or compilation failed; the driver's info log, if any,
// is printed to stderr. A failed shader object is always deleted, even when
// the driver reports an empty info log, so callers never get a non-zero
// handle for a shader that did not compile.
GLuint LoadShader(GLenum shaderType, const char* pSource) {
  GLuint shader = glCreateShader(shaderType);
  if (!shader) {
    return 0;
  }

  glShaderSource(shader, 1, &pSource, NULL);
  glCompileShader(shader);

  GLint compiled = 0;
  glGetShaderiv(shader, GL_COMPILE_STATUS, &compiled);
  if (compiled) {
    return shader;
  }

  GLint infoLen = 0;
  glGetShaderiv(shader, GL_INFO_LOG_LENGTH, &infoLen);
  if (infoLen > 0) {
    char* buf = static_cast<char*>(malloc(static_cast<size_t>(infoLen)));
    if (buf) {
      glGetShaderInfoLog(shader, infoLen, NULL, buf);
      fprintf(stderr, "Could not compile shader %d:\n%s\n",
              static_cast<int>(shaderType), buf);
      free(buf);
    }
  } else {
    fprintf(stderr, "Could not compile shader %d (no info log)\n",
            static_cast<int>(shaderType));
  }

  // Release the failed shader unconditionally (previously only done when a
  // non-empty info log was available).
  glDeleteShader(shader);
  return 0;
}

// Build a program containing one compute shader compiled from `pComputeSource`.
// Returns the linked program name, or 0 if compilation, program creation, or
// linking failed (diagnostics are printed to stderr).
// The intermediate shader object is always released before returning, so
// repeated calls do not leak shader objects.
GLuint CreateComputeProgram(const char* pComputeSource) {
  GLuint computeShader = LoadShader(GL_COMPUTE_SHADER, pComputeSource);
  if (!computeShader) {
    return 0;
  }

  GLuint program = glCreateProgram();
  if (program) {
    glAttachShader(program, computeShader);
    glLinkProgram(program);
    GLint linkStatus = GL_FALSE;
    glGetProgramiv(program, GL_LINK_STATUS, &linkStatus);
    if (linkStatus == GL_TRUE) {
      // The linked executable no longer needs the shader object; detaching
      // lets glDeleteShader below free it right away.
      glDetachShader(program, computeShader);
    } else {
      GLint bufLength = 0;
      glGetProgramiv(program, GL_INFO_LOG_LENGTH, &bufLength);
      if (bufLength > 0) {
        char* buf = static_cast<char*>(malloc(static_cast<size_t>(bufLength)));
        if (buf) {
          glGetProgramInfoLog(program, bufLength, NULL, buf);
          fprintf(stderr, "Could not link program:\n%s\n", buf);
          free(buf);
        }
      }
      // Deleting the program also detaches the shader.
      glDeleteProgram(program);
      program = 0;
    }
  }

  // Previously the shader object leaked on every path.
  glDeleteShader(computeShader);
  return program;
}

// copy data from cpu to gpu
void GLCopyHostToDevice(float *input, GLuint textureId, int width, int height, int channel) {
  GLuint computeProgram = CreateComputeProgram(glsl_host_to_device);
  glUseProgram(computeProgram);

  // bind the dest image texture
  glBindImageTexture(0, textureId, 0, GL_TRUE, 0, GL_WRITE_ONLY, GL_RGBA32F);

  // bind the src input data
  CreateSSBO(1, input, ROUND_UP(channel, 4) * width * height);

  // set uniform values
  glUniform1i(2, width);
  glUniform1i(3, height);
  glUniform1i(4, channel);

  int c_4 = UP_DIV(channel, 4);
  glDispatchCompute(UP_DIV(width, gLocalSize[0]), UP_DIV(height, gLocalSize[1]), UP_DIV(c_4, gLocalSize[2]));
  glDeleteProgram(computeProgram);
}

// Copy data from GPU to CPU.
// Reads the RGBA32F 3D texture `textureId` back into `output`, which must
// hold at least width*height*channel floats. Data is laid out as
// (y*width + x)*channel + c, and texture depth slice z supplies channels
// 4*z .. 4*z+3. The texture is first flattened into an SSBO by the
// glsl_device_to_host compute shader; that SSBO is then mapped and copied.
// Does nothing for a null output or non-positive dimensions, or if the
// compute program fails to build (errors are reported on stderr).
void GLCopyDeviceToHost(GLuint textureId, float *output, int width, int height, int channel) {
  if (output == nullptr || width <= 0 || height <= 0 || channel <= 0) {
    return;
  }

  GLuint computeProgram = CreateComputeProgram(glsl_device_to_host);
  if (computeProgram == 0) {
    return;
  }
  glUseProgram(computeProgram);

  // bind the src image texture
  glBindImageTexture(0, textureId, 0, GL_TRUE, 0, GL_READ_ONLY, GL_RGBA32F);

  // bind the dest output data
  GLuint destId = CreateSSBO(1, NULL, ROUND_UP(channel, 4) * width * height);

  // set uniform values
  glUniform1i(2, width);
  glUniform1i(3, height);
  glUniform1i(4, channel);

  // Ensure earlier imageStore() writes to the texture are visible to imageLoad().
  glMemoryBarrier(GL_SHADER_IMAGE_ACCESS_BARRIER_BIT);

  int c_4 = UP_DIV(channel, 4);
  glDispatchCompute(UP_DIV(width, gLocalSize[0]), UP_DIV(height, gLocalSize[1]), UP_DIV(c_4, gLocalSize[2]));

  // GL_BUFFER_UPDATE_BARRIER_BIT orders shader writes before a later
  // glMapBufferRange; GL_SHADER_STORAGE_BARRIER_BIT only covers subsequent
  // shader accesses to the SSBO.
  glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT | GL_BUFFER_UPDATE_BARRIER_BIT);

  // glMapBufferRange takes its length in bytes, not in floats.
  const size_t outputBytes =
      static_cast<size_t>(width) * height * channel * sizeof(float);
  glBindBuffer(GL_SHADER_STORAGE_BUFFER, destId);
  void* ptr = glMapBufferRange(GL_SHADER_STORAGE_BUFFER, 0, outputBytes, GL_MAP_READ_BIT);
  if (ptr != nullptr) {
    ::memcpy(output, ptr, outputBytes);
    // Only unmap a buffer that was actually mapped.
    glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
  }

  // Release the staging SSBO (it previously leaked on every call).
  glDeleteBuffers(1, &destId);
  glDeleteProgram(computeProgram);
}

// Allocate GPU storage as an immutable RGBA32F 3D texture with one mip level.
// Its extent is width x height x UP_DIV(channel, 4), so each RGBA texel
// packs four channels. Sampling uses nearest filtering and clamp-to-edge
// wrapping on all three axes. The new texture is left bound to
// GL_TEXTURE_3D and its name is returned; release it with glDeleteTextures.
GLuint GLMalloc(int width, int height, int channel) {
  GLuint textureId;
  glGenTextures(1, &textureId);
  glBindTexture(GL_TEXTURE_3D, textureId);

  const GLenum filterParams[] = {GL_TEXTURE_MIN_FILTER, GL_TEXTURE_MAG_FILTER};
  for (GLenum param : filterParams) {
    glTexParameteri(GL_TEXTURE_3D, param, GL_NEAREST);
  }

  const GLenum wrapParams[] = {GL_TEXTURE_WRAP_S, GL_TEXTURE_WRAP_T, GL_TEXTURE_WRAP_R};
  for (GLenum param : wrapParams) {
    glTexParameteri(GL_TEXTURE_3D, param, GL_CLAMP_TO_EDGE);
  }

  const int depth = UP_DIV(channel, 4);
  glTexStorage3D(GL_TEXTURE_3D, 1, GL_RGBA32F, width, height, depth);
  return textureId;
}